from typing import TYPE_CHECKING, Optional

from quixstreams.core.stream import Stream, VoidExecutor
from quixstreams.models import Topic

from .exceptions import (
    GroupByDuplicate,
    GroupByNestingLimit,
    StreamingDataFrameDuplicate,
)

if TYPE_CHECKING:
    from .dataframe import StreamingDataFrame


class DataFrameRegistry:
    """
    Helps manage multiple `StreamingDataFrames` (multi-topic `Applications`)
    and their respective repartitions.

    The registry tracks:
    - the root Stream registered for each input topic (see `register_root`);
    - the stream ids of SDFs produced by `SDF.group_by()` (repartitions);
    - a two-way mapping between stream ids and topic names
      (see `register_stream_id`);
    - whether any registered SDF requires timestamp-aligned consumption.
    """

    # Shared by both duplicate-detection paths in `register_groupby`.
    _GROUPBY_DUPLICATE_MESSAGE = (
        "An `SDF.group_by()` operation appears to be the same as another, "
        "either from using the same column or name parameter; "
        "adjust by setting a unique name with `SDF.group_by(name=<NAME>)` "
    )

    def __init__(self) -> None:
        # Root Stream of each registered SDF, keyed by its topic name.
        self._registry: dict[str, Stream] = {}
        # Registered topics, in registration order (see `consumer_topics`).
        self._topics: list[Topic] = []
        # Stream ids of SDFs created via `SDF.group_by()`.
        self._repartition_origins: set[str] = set()
        # Two-way many-to-many mapping filled by `register_stream_id`.
        self._topics_to_stream_ids: dict[str, set[str]] = {}
        self._stream_ids_to_topics: dict[str, set[str]] = {}
        self._requires_time_alignment = False

    @property
    def requires_time_alignment(self) -> bool:
        """
        Check if registered StreamingDataFrames require topics to be read in timestamp-aligned way.
        That's normally required for the operations like `.concat()` and joins.
        """
        return self._requires_time_alignment

    @property
    def consumer_topics(self) -> list[Topic]:
        """
        :return: a list of Topics a consumer should subscribe to,
            in the order they were registered.
        """
        return self._topics

    def register_root(
        self,
        dataframe: "StreamingDataFrame",
    ) -> None:
        """
        Register a StreamingDataFrame to process data from the topic.

        The provided SDF must belong to exactly one topic.
        Only one SDF can be registered for the given topic.

        :param dataframe: StreamingDataFrame instance
        :raises ValueError: if the SDF does not have exactly one topic.
        :raises StreamingDataFrameDuplicate: if an SDF is already registered
            for the same topic.
        """
        topics = dataframe.topics
        # `!= 1` (not `> 1`) so an SDF with no topics fails with the documented
        # ValueError instead of an IndexError on `topics[0]` below.
        if len(topics) != 1:
            raise ValueError(
                f"Expected a StreamingDataFrame with one topic, got {len(topics)}"
            )
        topic = topics[0]

        if topic.name in self._registry:
            raise StreamingDataFrameDuplicate(
                f"There is already a StreamingDataFrame using topic {topic.name}"
            )
        self._topics.append(topic)
        self._registry[topic.name] = dataframe.stream

    def register_groupby(
        self,
        source_sdf: "StreamingDataFrame",
        new_sdf: "StreamingDataFrame",
        register_new_root: bool = True,
    ) -> None:
        """
        Register a "groupby" SDF, which is one generated with `SDF.group_by()`.

        :param source_sdf: the SDF used by `sdf.group_by()`
        :param new_sdf: the SDF generated by `sdf.group_by()`.
        :param register_new_root: whether to register the new SDF as a root SDF.
        :raises GroupByNestingLimit: if `source_sdf` is itself a group-by result.
        :raises GroupByDuplicate: if `new_sdf` (or its topic) is already registered.
        :raises ValueError: propagated from `register_root` when `new_sdf`
            does not have exactly one topic.
        """
        if source_sdf.stream_id in self._repartition_origins:
            raise GroupByNestingLimit(
                "Subsequent (nested) `SDF.group_by()` operations are not allowed."
            )

        if new_sdf.stream_id in self._repartition_origins:
            raise GroupByDuplicate(self._GROUPBY_DUPLICATE_MESSAGE)

        # Register the root before marking the repartition origin, so a failed
        # registration leaves the registry unchanged.
        if register_new_root:
            try:
                self.register_root(new_sdf)
            except StreamingDataFrameDuplicate as exc:
                raise GroupByDuplicate(self._GROUPBY_DUPLICATE_MESSAGE) from exc

        self._repartition_origins.add(new_sdf.stream_id)

    def compose_all(
        self, sink: Optional[VoidExecutor] = None
    ) -> dict[str, VoidExecutor]:
        """
        Composes all the Streams and returns a dict of format {<topic>: <VoidExecutor>}

        :param sink: callable to accumulate the results of the execution, optional.
        :return: a {topic_name: composed} dict, where composed is a callable
        """
        executors: dict[str, VoidExecutor] = {}
        # Go over the registered topics with root Streams and compose them
        for topic, root_stream in self._registry.items():
            # If a root stream is connected to other roots, ".compose()" will
            # return them all.
            # Use the registered root Stream to filter them out.
            root_executors = root_stream.compose(sink=sink)
            executors[topic] = root_executors[root_stream]
        return executors

    def register_stream_id(self, stream_id: str, topic_names: list[str]) -> None:
        """
        Register a mapping between the stream_id and topic names.
        This mapping is later used to match topics to state stores
        during assignment and commits.

        The same stream id can be registered multiple times; repeated
        registrations are merged (sets), not duplicated.

        :param stream_id: stream id of StreamingDataFrame
        :param topic_names: list of topics to map the stream id with
        """
        for topic_name in topic_names:
            self._topics_to_stream_ids.setdefault(topic_name, set()).add(stream_id)
            self._stream_ids_to_topics.setdefault(stream_id, set()).add(topic_name)

    def get_stream_ids(self, topic_name: str) -> list[str]:
        """
        Get a list of stream ids for the given topic name

        :param topic_name: a name of the topic
        :return: a list of stream ids (order not guaranteed)
        :raises KeyError: if no stream id was registered for the topic.
        """
        return list(self._topics_to_stream_ids[topic_name])

    def get_topics_for_stream_id(self, stream_id: str) -> list[str]:
        """
        Get a list of topics for the given stream id.

        :param stream_id: stream id
        :return: a list of topic names (order not guaranteed)
        :raises KeyError: if no topic was registered for the stream id.
        """
        return list(self._stream_ids_to_topics[stream_id])

    def require_time_alignment(self) -> None:
        """
        Require the time alignment for the topology.

        This flag is set by individual StreamingDataFrames when certain operations like
        .concat() or joins are triggered, and it will inform the application to consume
        messages in the timestamp-aligned way for the correct processing.
        """
        self._requires_time_alignment = True
